"""Weighted low rank approximation based on column norm sampling."""

import numpy as np

from . import weighted_linreg

def weighted_lra(matrix, weight, rank, *, rng=None):
  """Approximate `matrix` using `rank` sampled columns and a weighted solve.

  Column indices are drawn with probability proportional to each column's
  squared Euclidean norm. The drawn columns, unscaled, form the left factor.
  The right factor is then obtained from
  `weighted_linreg.weighted_linreg(matrix, weight, left_factor)`.

  Args:
    matrix: array to approximate. It is indexed as `matrix[:, j]`, so its
      columns lie along axis 1.
    weight: forwarded unchanged to `weighted_linreg.weighted_linreg`.
      Presumably it has the same shape as `matrix`; confirm against that
      module.
    rank: number of column indices to draw. Sampling is with replacement
      (the `choice` default). The same column may therefore appear more than
      once, and the left factor can have fewer than `rank` linearly
      independent columns.
    rng: optional random source exposing `choice(a, size, p=...)`, e.g.
      `np.random.default_rng(seed)` or `np.random.RandomState(seed)`.
      When None, numpy's global random state is used, as before.

  Returns:
    Tuple `(left_factor, right_factor)`. `left_factor` is
    `matrix[:, col_sample]`. `right_factor` is whatever
    `weighted_linreg.weighted_linreg` returns.

  Raises:
    ValueError: if `matrix` has no columns or all of its columns are zero,
      because no sampling distribution can be formed.
  """
  col_norms = np.linalg.norm(matrix, axis=0)
  sq_norms = col_norms * col_norms
  total = np.sum(sq_norms)
  # An all-zero or column-less matrix would turn the probabilities into
  # 0/0 = NaN. Fail here with a clear message instead of a RuntimeWarning
  # followed by an opaque error from `choice`.
  if not total > 0:
    raise ValueError(
        "weighted_lra requires a matrix with at least one nonzero column")
  probs = sq_norms / total
  sampler = np.random if rng is None else rng
  # Sample column indices (not rows) proportionally to squared column norms.
  col_sample = sampler.choice(len(probs), rank, p=probs)
  left_factor = matrix[:, col_sample]
  # Solve for the right factor, using the sampled columns as the left factor.
  right_factor = weighted_linreg.weighted_linreg(matrix, weight, left_factor)
  return left_factor, right_factor
